package com.jstarcraft.ai.jsat.clustering.hierarchical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.SimpleDataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.clustering.ClusterFailureException;
import com.jstarcraft.ai.jsat.clustering.KClusterer;
import com.jstarcraft.ai.jsat.clustering.KClustererBase;
import com.jstarcraft.ai.jsat.clustering.evaluation.ClusterEvaluation;

/**
 * DivisiveGlobalClusterer is a hierarchical clustering method that works by
 * splitting the data set into sub trees from the top down. Unlike many top-up
 * methods, such as {@link SimpleHAC}, top-down methods require another
 * clustering method to perform the splitting at each iteration. If the base
 * method is not deterministic, then the top-down method will not be
 * deterministic. <br>
 * Like many HAC methods, DivisiveGlobalClusterer will store the merge order of
 * the clusters so that the clustering results for many <i>k</i> can be
 * obtained. It is limited to the range of clusters successfully computed
 * before. <br>
 * <br>
 * Specifically, DivisiveGlobalClusterer greedily chooses the cluster to split
 * based on an evaluation of all resulting clusters after a split. Because of
 * this global search of the world, DivisiveLocalClusterer has can make a good
 * estimate of the number of clusters in the data set. The quality of this
 * result is dependent on the accuracy of the {@link ClusterEvaluation} used.
 * This quality comes at the cost of execution speed, as more and more large
 * evaluations of the whole dataset are needed at each iteration. If execution
 * speed is more important, {@link DivisiveLocalClusterer} should be used
 * instead, which requires only a fixed number of evaluations per iteration.
 * 
 * @author Edward Raff
 */
public class DivisiveGlobalClusterer extends KClustererBase {

    private static final long serialVersionUID = -9117751530105155090L;
    private KClusterer baseClusterer;
    private ClusterEvaluation clusterEvaluation;

    private int[] splitList;
    private int[] fullDesignations;
    private DataSet originalDataSet;

    public DivisiveGlobalClusterer(KClusterer baseClusterer, ClusterEvaluation clusterEvaluation) {
        this.baseClusterer = baseClusterer;
        this.clusterEvaluation = clusterEvaluation;
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    public DivisiveGlobalClusterer(DivisiveGlobalClusterer toCopy) {
        this.baseClusterer = toCopy.baseClusterer.clone();
        this.clusterEvaluation = toCopy.clusterEvaluation.clone();
        if (toCopy.splitList != null)
            this.splitList = Arrays.copyOf(toCopy.splitList, toCopy.splitList.length);
        if (toCopy.fullDesignations != null)
            this.fullDesignations = Arrays.copyOf(toCopy.fullDesignations, toCopy.fullDesignations.length);
        this.originalDataSet = toCopy.originalDataSet.shallowClone();
    }

    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return cluster(dataSet, 2, (int) Math.sqrt(dataSet.size()), parallel, designations);
    }

    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        return cluster(dataSet, clusters, clusters, parallel, designations);
    }

    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        if (designations == null)
            designations = new int[dataSet.size()];
        /**
         * Is used to copy the value of designations and then alter to test the quality
         * of a potential new clustering
         */
        int[] fakeWorld = new int[dataSet.size()];

        /**
         * For each current cluster, we store the clustering results if we attempt to
         * split it into two. <br>
         * Each row needs to be re-set since the clustering methods will use the length
         * of the cluster size
         */
        final int[][] subDesignation = new int[highK][];
        /**
         * Stores the index from the sub data set into the full data set
         */
        final int[][] originalPositions = new int[highK][dataSet.size()];
        /**
         * List of Lists for holding the data points of each cluster in
         */
        List<List<DataPoint>> pointsInCluster = new ArrayList<>(highK);
        for (int i = 0; i < highK; i++)
            pointsInCluster.add(new ArrayList<>(dataSet.size()));

        /**
         * Stores the dissimilarity of the splitting of the cluster with the same index
         * value. Negative value indicates not set. Special values:<br>
         * <ul>
         * <li>NEGATIVE_INFINITY : value never used</li>
         * <li>-1 : clustering computed, but no current global evaluation</li>
         * <li>>=0 : cluster computed, current value is the evaluation for using this
         * split</li>
         * </ul>
         */
        final double[] splitEvaluation = new double[highK];
        Arrays.fill(splitEvaluation, Double.NEGATIVE_INFINITY);

        /**
         * Records the order in which items were split
         */
        splitList = new int[highK * 2 - 2];
        int bestK = -1;
        double bestKEval = Double.POSITIVE_INFINITY;

        // k is the current number of clusters, & the ID of the next cluster
        for (int k = 1; k < highK; k++) {
            double bestSplitVal = Double.POSITIVE_INFINITY;
            int bestID = -1;

            for (int z = 0; z < k; z++)// TODO it might be better to do this loop in parallel
            {
                if (Double.isNaN(splitEvaluation[z]))
                    continue;
                else if (splitEvaluation[z] == Double.NEGATIVE_INFINITY)// at most 2 will hit this per loop
                {// Need to compute a split for that cluster & set up helper structures
                    List<DataPoint> clusterPointsZ = pointsInCluster.get(z);
                    clusterPointsZ.clear();
                    for (int i = 0; i < dataSet.size(); i++) {
                        if (designations[i] != z)
                            continue;
                        originalPositions[z][clusterPointsZ.size()] = i;
                        clusterPointsZ.add(dataSet.getDataPoint(i));
                    }
                    subDesignation[z] = new int[clusterPointsZ.size()];
                    if (clusterPointsZ.isEmpty())// Empty cluster? How did that happen...
                    {
                        splitEvaluation[z] = Double.NaN;
                        continue;
                    }
                    SimpleDataSet subDataSet = new SimpleDataSet(clusterPointsZ);

                    try {
                        baseClusterer.cluster(subDataSet, 2, parallel, subDesignation[z]);
                    } catch (ClusterFailureException ex) {
                        splitEvaluation[z] = Double.NaN;
                        continue;
                    }
                }

                System.arraycopy(designations, 0, fakeWorld, 0, fakeWorld.length);
                for (int i = 0; i < subDesignation[z].length; i++)
                    if (subDesignation[z][i] == 1)
                        fakeWorld[originalPositions[z][i]] = k;
                try {
                    splitEvaluation[z] = clusterEvaluation.evaluate(fakeWorld, dataSet);
                } catch (Exception ex)// Can occur if one of the clusters has size zeros
                {
                    splitEvaluation[z] = Double.NaN;
                    continue;
                }

                if (splitEvaluation[z] < bestSplitVal) {
                    bestSplitVal = splitEvaluation[z];
                    bestID = z;
                }
            }

            // We now know which cluster we should use the split of
            for (int i = 0; i < subDesignation[bestID].length; i++)
                if (subDesignation[bestID][i] == 1)
                    designations[originalPositions[bestID][i]] = k;

            // The original clsuter id, and the new one should be set to -Inf
            splitEvaluation[bestID] = splitEvaluation[k] = Double.NEGATIVE_INFINITY;

            // Store a split list
            splitList[(k - 1) * 2] = bestID;
            splitList[(k - 1) * 2 + 1] = k;
            if (lowK - 1 <= k && k <= highK - 1)// Should we stop?
            {
                if (bestSplitVal < bestKEval) {
                    bestKEval = bestSplitVal;
                    bestK = k;
//                    System.out.println("Best k is now " + k + " at " + bestKEval);
                }
            }
        }

        fullDesignations = Arrays.copyOf(designations, designations.length);

        // Merge the split clusters back to the one that had the best score
        for (int k = splitList.length / 2 - 1; k >= bestK; k--) {
            if (splitList[k * 2] == splitList[k * 2 + 1])
                continue;// Happens when we bail out early
            for (int j = 0; j < designations.length; j++)
                if (designations[j] == splitList[k * 2 + 1])
                    designations[j] = splitList[k * 2];
        }

        originalDataSet = dataSet;
        return designations;
    }

    /**
     * Returns the clustering results for a specific <i>k</i> number of clusters for
     * a previously computed data set. If the data set did not compute up to the
     * value <i>k</i> <tt>null</tt> will be returned.
     * 
     * @param targetK the number of clusters to get the result for.
     * @return an array containing the assignments for each cluster in the original
     *         data set.
     * @throws ClusterFailureException if no prior data set had been clustered
     */
    public int[] clusterSplit(int targetK) {
        if (originalDataSet == null)
            throw new ClusterFailureException("No prior cluster stored");
        int[] newDesignations = Arrays.copyOf(fullDesignations, fullDesignations.length);
        // Merge the split clusters back to the one that had the best score
        for (int k = splitList.length / 2 - 1; k >= targetK; k--) {
            if (splitList[k * 2] == splitList[k * 2 + 1])
                continue;// Happens when we bail out early
            for (int j = 0; j < newDesignations.length; j++)
                if (newDesignations[j] == splitList[k * 2 + 1])
                    newDesignations[j] = splitList[k * 2];
        }
        return newDesignations;
    }

    @Override
    public DivisiveGlobalClusterer clone() {
        return new DivisiveGlobalClusterer(this);
    }

}
